In [56]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Auto-reload edited local modules (the src/ package below) without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [ ]:
import argparse
import os
import sys
import numpy as np
import torch
from matplotlib import pyplot as plt
import pandas as pd
In [61]:
# Make the project package importable from this notebook.
# The project root can be overridden with the TOOL_PRESENCE_ROOT environment
# variable; the default keeps the original hard-coded location so existing
# runs are unaffected. (Single-arg os.path.join was a no-op and is dropped.)
module_path = os.path.abspath(
    os.environ.get('TOOL_PRESENCE_ROOT', '/users/dli44/tool-presence'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src import constants as c
from src import utils
from src import visualization as v
from src import model as m
In [ ]:
# Build the project's CLI parser and parse a fixed notebook configuration
# (MMD-VAE on the 64x64 surgical dataset).
notebook_args = [
    '--root=/users/dli44/tool-presence/',
    '--data-dir=data/surgical_data/',
    '--image-size=64',
    '--loss-function=mmd',
]
parser = utils.setup_argparse()
args = parser.parse_args(args=notebook_args)
In [ ]:
# Build datasets and dataloaders; augmentation disabled since this notebook
# only evaluates a pretrained model.
# NOTE(review): assumed `datasets` is a dict keyed by 'train'/'val' based on
# the indexing used in later cells — confirm against src/utils.py.
datasets, dataloaders = utils.setup_data(args, augmentation=False)
In [ ]:
# Checkpoint selection for the pretrained MMD-VAE.
# NOTE(review): `load_model` is never read anywhere else in this notebook —
# candidate for removal, or the load cell below should be guarded by it.
load_model = True
model_name = "mmd_beta_1.0_epoch_50.torch"
model_path = os.path.join(args.root, 'data/mmd_vae', model_name)
In [ ]:
# Instantiate the VAE and move it to the configured device.
# h_dim1/h_dim2 are internal hidden-layer widths (presumably encoder/decoder
# layer sizes — confirm against src/model.py); zdim comes from the CLI args.
model = m.VAE(image_channels=args.image_channels,
              image_size=args.image_size,
              h_dim1=1024,
              h_dim2=128,
              zdim=args.z_dim).to(c.device)
In [ ]:
# Load the checkpoint onto the active device. map_location prevents a
# RuntimeError when the checkpoint was saved on GPU but this kernel runs on
# CPU (or vice versa).
state_dict = torch.load(model_path, map_location=c.device)
model.load_state_dict(state_dict)
# Inference mode: make dropout/batch-norm (if present in the VAE) behave
# deterministically for the reconstructions and interpolations below.
model.eval()
In [ ]:
# Frame-level labels for the surgical dataset.
# NOTE(review): `labels` is not used later in this notebook — confirm intent
# or remove the cell.
labels = pd.read_csv(os.path.join(args.root, args.data_dir, 'surgical_labels.csv'))
In [ ]:
# Show the two validation frames used as interpolation endpoints, side by
# side (tensors are CHW; transpose to HWC for imshow).
start_frame = datasets['val'][1][0].numpy().transpose(1, 2, 0)
end_frame = datasets['val'][9][0].numpy().transpose(1, 2, 0)

fig = plt.figure()
plt.title("Initial Images\nStart, End")
plt.imshow(np.hstack([start_frame, end_frame]))
In [ ]:
# Reconstruct both endpoint frames and compare originals (top row) against
# reconstructions (bottom row).
fig = plt.figure()

# no_grad: inference only — avoids building an autograd graph just to plot.
# The original code also bound the latent output to `z` twice (first value
# silently overwritten, never used); unused outputs are discarded with `_`.
with torch.no_grad():
    recon1, _, _, _ = model(datasets['val'][1][0].unsqueeze(0).to(c.device))
    recon2, _, _, _ = model(datasets['val'][9][0].unsqueeze(0).to(c.device))

recon1 = utils.torch_to_image(recon1)
recon2 = utils.torch_to_image(recon2)

originals = np.hstack([utils.torch_to_image(datasets['val'][1][0]),
                       utils.torch_to_image(datasets['val'][9][0])])
recons = np.hstack([recon1, recon2])

plt.imshow(np.vstack([originals, recons]))
In [74]:
# Interpolate between the two validation frames in latent space and save the
# resulting strip.
# NOTE(review): the title says "Beta=5" but the loaded checkpoint is
# mmd_beta_1.0_epoch_50 — confirm which beta this figure actually shows.
images = v.latent_interpolation(datasets['val'][1][0], 
                                datasets['val'][9][0], 
                                model=model)

fig = v.plot_interpolation(images, "Interpolation\nBeta=5")

plt.savefig(os.path.join(args.root,
                         'data/mmd_vae',
                         'mmd_tool_motion.png'), bbox_inches='tight', dpi=400, pad_inches=0.0)
In [65]:
# Encode the two endpoint frames and keep their latent codes as 1-D numpy
# vectors (`a`, `b`); `diff` is their per-dimension difference.
# (Names a/b/diff are kept: later cells read them from kernel state.)
latent_start = v.get_latent_vector(datasets['val'][1][0], model)
latent_end = v.get_latent_vector(datasets['val'][9][0], model)
a = utils.torch_to_numpy(latent_start)[0]
b = utils.torch_to_numpy(latent_end)[0]
diff = a - b
In [66]:
# Overlay the two latent codes dimension-by-dimension. Labels/legend make the
# figure readable on its own; trailing ';' suppresses the Line2D repr that
# was previously leaking into the cell output.
fig = plt.figure()
plt.title("Latent vectors of endpoint frames")
plt.plot(a, label="start (val[1])")
plt.plot(b, label="end (val[9])")
plt.legend();
Out[66]:
[<matplotlib.lines.Line2D at 0x7ffa153cf6d8>]
In [67]:
# Per-dimension difference between the endpoint latent codes; spikes mark
# dimensions that drive the start-to-end change. Title added so the figure
# stands alone; trailing ';' suppresses the Line2D repr.
fig = plt.figure()
plt.title("Latent difference (start - end)")
plt.plot(a - b);
Out[67]:
[<matplotlib.lines.Line2D at 0x7ffa153b7198>]
In [81]:
# Sweep each latent dimension of the val[1] frame to inspect what it encodes.
# NOTE(review): the magic number 64 presumably equals args.z_dim — confirm
# and consider `range(args.z_dim)`. Generating 64 figures in one cell is
# memory-heavy, but plt.close(fig) would suppress inline display, so the
# loop is left as-is.
for zdim in range(64):
    images = v.explore_latent_dimension(datasets['val'][1][0], model, zdim=zdim)
    fig = v.plot_interpolation(images, title='zdim {}'.format(zdim))
In [77]:
# Show the two training frames (indices 360 and 368) used as the second pair
# of interpolation endpoints (CHW tensors transposed to HWC for imshow).
first_frame = datasets['train'][360][0].numpy().transpose(1, 2, 0)
second_frame = datasets['train'][368][0].numpy().transpose(1, 2, 0)

fig = plt.figure()
plt.title("Initial Images\nStart, End")
plt.imshow(np.hstack([first_frame, second_frame]))
Out[77]:
<matplotlib.image.AxesImage at 0x7ffa02af77f0>
In [78]:
# Latent-space interpolation between the two training frames; saved next to
# the validation figure.
# NOTE(review): same caveat as before — title says "Beta=5" but the loaded
# checkpoint is mmd_beta_1.0_epoch_50; confirm the intended caption.
images = v.latent_interpolation(datasets['train'][360][0], 
                                datasets['train'][368][0], 
                                model=model)

fig = v.plot_interpolation(images, "Interpolation\nBeta=5")

plt.savefig(os.path.join(args.root,
                         'data/mmd_vae',
                         'mmd_tool_motion2.png'), bbox_inches='tight', dpi=400, pad_inches=0.0)
In [82]:
# Same latent-dimension sweep as above, on a training frame.
# NOTE(review): 64 presumably equals args.z_dim — confirm; consider
# `range(args.z_dim)` to keep this in sync with the model configuration.
for zdim in range(64):
    images = v.explore_latent_dimension(datasets['train'][360][0], model, zdim=zdim)
    fig = v.plot_interpolation(images, title='zdim {}'.format(zdim))
In [ ]: